spec: support MTP #6
After the refactoring, all the state management of the draft context is performed outside of `common/speculative.cpp`:

```diff
diff --git a/common/speculative.cpp b/common/speculative.cpp
index ef13edd34..95329b8a6 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -592,19 +592,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
         auto & draft_tokens = *dp.result;
         draft_tokens.clear();
-        if (last_n_drafted[seq_id] > 0) {
-            const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
-            if (n_to_drop > 0) {
-                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-                if (pos_max >= 0) {
-                    const llama_pos drop_from = pos_max - n_to_drop + 1;
-                    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
-                }
-            }
-            last_n_drafted[seq_id] = 0;
-            last_n_accepted[seq_id] = 0;
-        }
-
         // Effective draft length: min(global cap, per-sequence override).
         int32_t n_max = std::max(1, params.n_max);
         if (dp.n_max > 0) {
@@ -673,32 +660,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
             cond_tok = best;
             ++pos;
         }
-
-        last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
     }
 
     void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
-        GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());
-
-        auto * ctx_dft = this->params.ctx_dft;
-
-        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
-        const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];
-
-        const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
-
-        if (pos_max < 0) {
-            last_n_accepted[seq_id] = (int32_t) n_accepted;
-            return;
-        }
-
-        if (n_to_drop > 0) {
-            const llama_pos drop_from = pos_max - n_to_drop + 1;
-            llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
-        }
-
-        last_n_drafted [seq_id] = 0;
-        last_n_accepted[seq_id] = (int32_t) n_accepted;
     }
 };
```
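For illustration, here is a minimal sketch of what that caller-side state management could look like, using the same public memory API as the removed code. The helper name and call site are assumptions, not the PR's actual code; it just mirrors the arithmetic of the removed rollback logic from the caller's perspective.

```cpp
// Hypothetical caller-side rollback helper - not the PR's actual code.
// After the target model accepts n_accepted of n_drafted draft tokens,
// drop the rejected suffix from the draft context's KV cache.
#include "llama.h"

#include <algorithm>
#include <cstdint>

static void drop_rejected_draft(llama_context * ctx_dft, llama_seq_id seq_id,
                                int32_t n_drafted, int32_t n_accepted) {
    // same arithmetic as the removed code: n_drafted - n_accepted - 1
    const int32_t n_to_drop = std::max(0, n_drafted - n_accepted - 1);
    if (n_to_drop == 0) {
        return;
    }

    llama_memory_t mem = llama_get_memory(ctx_dft);

    const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);
    if (pos_max < 0) {
        return; // nothing stored for this sequence yet
    }

    // remove [drop_from, end) of the sequence from the draft KV cache
    const llama_pos drop_from = pos_max - n_to_drop + 1;
    llama_memory_seq_rm(mem, seq_id, drop_from, -1);
}
```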
Give me ~1 hour and I'll open a PR here to simplify (wip: https://github.com/ggml-org/llama.cpp/tree/gg/spec-mtp-experiments)
In the partial rollback implementation, the accepted batch is not re-evaluated with the draft context, correct? I think this will narrow the difference a bit, though I'm not sure by how much. Here are the
on my DGX Spark (patched by adding a draft acceptance loop)
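For context, a "draft acceptance loop" here just means feeding the tokens accepted by the target model back through the draft/MTP context so its KV cache stays in sync. A rough sketch of the idea (the helper and its signature are made up for illustration, not the patch that was actually applied):

```cpp
// Illustrative only - decode the accepted tokens on the draft context so that
// its KV cache reflects what the target model actually kept.
#include "common.h"
#include "llama.h"

#include <vector>

static bool draft_accept_loop(llama_context * ctx_dft, llama_seq_id seq_id,
                              llama_pos pos0, const std::vector<llama_token> & accepted) {
    llama_batch batch = llama_batch_init((int32_t) accepted.size(), 0, 1);

    common_batch_clear(batch);
    for (size_t i = 0; i < accepted.size(); ++i) {
        // logits are not needed - we only want the draft KV cache updated
        common_batch_add(batch, accepted[i], pos0 + (llama_pos) i, { seq_id }, false);
    }

    const bool ok = llama_decode(ctx_dft, batch) == 0;

    llama_batch_free(batch);

    return ok;
}
```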
Another thing is
Basically at low acceptance rates < 0.5, the speed difference is going to be much larger. From anecdotal usage, using this PR I seem to even hit 9 toks/sec when doing real coding work, whereas with partial rollback I never hit below 14 toks/sec even when acceptance is low. You can perhaps try and use it; I felt the difference is quite real.
Did you use this branch or #7?
I used this branch, just saw #7
Just tried #7 as well, Qwen3.6 27B - `"wall_s_total": 100.33`. Somehow acceptance rates are suspiciously high, maybe some accounting error. For reference in
With the
You can observe the accepted drafts with
Yes, I'm also not sure. On Mac it is always useful for some reason. On CUDA sometimes it helps, sometimes not. In any case, it can be adjusted with the

Regarding the partial rollback - it does bring a noticeable benefit on CUDA. But I still don't see a good way to support it cleanly. Among other drawbacks, the compute graph is also no longer static. The logic is not compatible with ngram speculative decoding because it uses long drafts of ~64 tokens which still need to be checkpointed. And for some reason that I still don't understand, it does not seem to help much on Mac.
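To make the checkpointing remark concrete: one way to checkpoint a sequence before a long ngram draft and restore it afterwards is the public `llama_state_seq_*` API. This is only a sketch of the general idea, not the checkpoint mechanism used in the PR:

```cpp
// Sketch of per-sequence checkpoint/restore around a long draft - illustrative,
// not the PR's checkpoint implementation.
#include "llama.h"

#include <cstdint>
#include <vector>

struct seq_checkpoint {
    std::vector<uint8_t> data;
};

static seq_checkpoint seq_checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
    seq_checkpoint cp;
    cp.data.resize(llama_state_seq_get_size(ctx, seq_id));
    llama_state_seq_get_data(ctx, cp.data.data(), cp.data.size(), seq_id);
    return cp;
}

static bool seq_checkpoint_load(llama_context * ctx, llama_seq_id seq_id, const seq_checkpoint & cp) {
    // llama_state_seq_set_data returns the number of bytes read, or 0 on failure
    return llama_state_seq_set_data(ctx, cp.data.data(), cp.data.size(), seq_id) != 0;
}
```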
```cpp
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
    pending_pos[seq_id] = -1;
    return true;
}
```
I'm not really sure what the correct way to process the image embeddings with the MTP context is. In any case, vision MTP seems to already work to a good extent:
Here I ask it to OCR 100 random integers without speculative decoding and with MTP:
- Without spec decoding
- With MTP
With MTP it is ~2x faster, which means the MTP context "knows" about the integers in some way. But at the same time, I'm pretty sure that the current way of processing is not 100% correct, because the `inp->tokens` tensor in the MTP graph is being used with stale data when the input batch has image embeddings and no tokens.
I think we will figure this out later - not super important atm.
@am17an I think the changes are good overall. On my end, I will continue on top of ggml-org#22838 to support specifying multiple speculative decoding types like this: `--spec-type ngram-mod,mtp`. Should be a simple change and when ready, I will proceed with merging ggml-org#22838.
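A comma-separated value like that could be parsed along these lines (the enum and helper below are hypothetical, just to illustrate the shape of the change, not the actual llama.cpp arg-parsing code):

```cpp
// Hypothetical parser for a comma-separated --spec-type argument.
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

enum spec_type {
    SPEC_TYPE_NGRAM_MOD, // name assumed for illustration
    SPEC_TYPE_MTP,       // name assumed for illustration
};

static std::vector<spec_type> parse_spec_types(const std::string & arg) {
    std::vector<spec_type> res;

    std::stringstream ss(arg);
    std::string item;
    while (std::getline(ss, item, ',')) {
        if (item == "ngram-mod") {
            res.push_back(SPEC_TYPE_NGRAM_MOD);
        } else if (item == "mtp") {
            res.push_back(SPEC_TYPE_MTP);
        } else {
            throw std::invalid_argument("unknown --spec-type value: " + item);
        }
    }

    return res;
}

// usage: const auto types = parse_spec_types("ngram-mod,mtp");
```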
We can perhaps just enable this option when MTP is enabled as a spec mode for hybrid models. I think we can also make the compute graph static by only doing rollback when
We can iterate on it, but I don't think we can merge MTP directly with the partial rollback changes. These changes have to be in a follow-up PR because they affect a lot of logic: ggml, llama.cpp recurrent state, server logic, backend code. We have to merge something that is solid and works across all hardware, so we can in parallel continue to add other speculative decoding approaches. The partial rollback will be a potential optimization if we figure out how to do it cleanly.
Yes agreed, for other models it is not even required, so first we should get MTP into master and make it stable. There are issues with GGUF loading/unloading and general memory issues that need to be fixed. I will keep the partial rollback branch up-to-date so people are free to use it. So the plan is that you merge ggml-org#22838, and then I rebase github.com/ggml-org/pull/22673 on top of that with the changes here. And then we can probably have another round of review regarding the other parts of the code?
Ok sounds good.
Yes. I haven't looked at the prompt prefill at all yet, so I'm not sure what the status is there. I think this branch should perform a bit better thanks to pinned memory. The GGUF loading is probably the most important thing to figure out how to make user friendly.
Commenting just on the GGUF MTP approach: as a user I believe it would be best to align with the same packaging principles as mmproj and other spec decoding implementations (eagle, dflash, etc.) - keep optional model features in their own external GGUF for maximum flexibility at runtime.
--split-mode tensors become invalid and affect MTP speed. Removing
However, previously
The argument is
Sorry if someone already talked about this, but:
Flags used to reproduce llama-server checkpoint crash (fixed with
Logs of llama-server crashing with checkpoints

Btw, thanks ggerganov and am17an for the insane work being done here :]
If you increase the
This branch starts with
Still fails to load with
Also, I think the main thing is that it crashes with checkpoints, mainly this line (with
Ah, I missed that there was such logic. In that case, it is better to wait for the updated branch that will likely include this logic again. Here we are mainly prototyping the speculative architecture.
For testing I load it like so: `./build/bin/llama-server`. This works.
@ggerganov I think now the main PR is in a relatively good state |
Ok thanks. I realized there might be an issue with tensor parallel support - not sure if the device copies of the checkpoints are handled properly with multi-GPU tensor splits. Will look into this today and see if there is some fix - might need to get some feedback from Johannes.
Hrm, I did not see benefits on Strix Halo Vulkan - I am realizing this is known for MoE? I was running unsloth/Qwen3.5-35B-A3B-GGUF:UD-Q4_K_XL, typically 47-50 tps at long context, around 57 tps at short context, or ~47-50 tps at 30k context. On the MTP fork I didn't see an improvement: unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL got ~47-50 tps at long context.

I really didn't see a speedup at all. I tried am17an's model and saw around 11-18 tps. I also ran -np 1 and tried -fixx 2048 and --ctx-checkpoints 0 on all these. In general I saw draft acceptance rate = 1.00000, so it was for sure working. After reading this thread I used gg-mtp-rebase and had similar results.

I have experienced a few crashes or "pauses", where output just stops or doesn't happen at all. I did not run any of this long term to experience more. After my first pause/crash, each subsequent request logged srv params_from_: Chat format: peg-native. I was using the llama webui.

After reading this thread I also retested am17an/Qwen3.6-27B-MTP-GGUF on gg-mtp-rebase without/with MTP and saw ~7 tps go to ~10 tps. Being a non-MoE, I can confirm moar speed? Acceptance seemed more realistic to my brain as well, being less than 1 :)

I'm available to test multi-CUDA or Vulkan stuff if needed/wanted. I just navigated a bunch of PRs and discussion and decided to leave my experience here in case it's helpful.
I've been playing with the MTP PR. Since the MTP PR doesn't add support for llama-bench, I've been testing with a prompt:

With the PR built yesterday, commit e7b4848:

Looks like the PR was recently rebased on master; since then I can only run it with spec-draft-n-max=0!

New PR commit:

Something like the following:
I have removed the partial rollback changes and isolated the changes to just the Qwen models. Things to work out:

- n_seq > 1

Note that partial rollback is extremely important for the speed-up here - for the MoE model there is actually a slowdown with MTP on this branch.